package graybox;

import java.io.IOException;

/**
 * @author tonychao
 * 功能：RRS 算法的一个实现，用于对参数空间进行搜索；参考原作者论文的伪代码写出。
 */


import java.util.ArrayList;
import org.json.JSONException;
import confs.SingleParameter;
import fetcher.SparkConfiguration;
import others.Global;
import shellInterface.UserSubmitInterface_test;



/* 这个类是对一个 ML 预测模型的封装：输入的是学习的成果（训练好的预测模型），
 * 返回的是预测出的最优配置参数。
 *
 * 只针对同一应用进行建模，所以只处理 Run Record。
 */
public class ParameterSearch {

	/** Total number of calls made to the prediction function f() (across all instances). */
	public static int callCount = 0;
	/** Number of completed exploitation (local search) rounds. */
	public static int exploitCount = 0;

	/** Handle to the Python prediction daemon; null until init() has been called. */
	public CallPythonWindows h = null;

	public ParameterSearch() {
	}

	/**
	 * Starts the Python prediction daemon that serves runtime predictions for one
	 * application/model pair.
	 *
	 * @param app_md5          identifier (MD5) of the application being tuned
	 * @param className        model class name handed to the daemon
	 * @param pythonSaemonPath path of the Python daemon script
	 * @return the start status code reported by CallPythonWindows
	 */
	public int init(String app_md5, String className, String pythonSaemonPath) throws InterruptedException {
		// Single-threaded mode: one daemon per ParameterSearch instance.
		this.h = new CallPythonWindows();
		return this.h.pythonDaemonStart(pythonSaemonPath, new String[]{app_md5, className});
	}

	/**
	 * Stops and destroys the Python daemon process.
	 */
	public void end() {
		this.h.pythonDaemonEnd();
	}

	/**
	 * RRS search without any prior (seed) configurations.
	 *
	 * @param groupNameList parameter group names defining the search space
	 * @return the best configuration found
	 */
	public SparkConfiguration searchRRS(ArrayList<String> groupNameList) throws IOException, JSONException {
		return this.searchRRS(new ArrayList<SparkConfiguration>(), groupNameList);
	}

	/**
	 * Recursive Random Search (RRS) over the Spark configuration space, following the
	 * pseudocode of the original RRS paper (numbered comments mirror its steps).
	 * Alternates global exploration (random samples) with local exploitation
	 * (shrinking neighbourhood search) around points that beat a statistical baseline.
	 *
	 * Fixes vs. the previous revision: the daemon is shut down in a finally block so it
	 * no longer leaks when the search aborts; the "optimization ratio" debug output no
	 * longer truncates via integer division; removed an unused local (P) and a dead,
	 * empty if-block.
	 *
	 * @param listlist      prior (seed) configurations, consumed front-first before
	 *                      random exploration points are drawn
	 * @param groupNameList parameter group names defining the search space
	 * @return the best configuration found (Xopt)
	 */
	public SparkConfiguration searchRRS(ArrayList<SparkConfiguration> listlist, ArrayList<String> groupNameList) throws IOException, JSONException {

		ArrayList<Integer> F = new ArrayList<Integer>();
		// 1 Initialize exploration parameters p, r; sampleCount1 = ln(1-p)/ln(1-r)
		double p = 0.99, r = 0.01; // confidence level p, per-sample coverage r
		int sampleCount1 = (int) (Math.log(1 - p) / Math.log(1 - r)); // number of random exploration samples
		// 2 Initialize exploitation parameters q, v; sampleCount2 = ln(1-q)/ln(1-v)
		double q = 0.9, v = 0.1; // confidence level q, target shrink quantile v
		double c = 0.25; // shrink ratio applied to the local sampling neighbourhood
		double shrinkRatio1 = 0.02;
		double shrinkRatio2 = 0.1;
		int sampleCount2 = (int) (Math.log(1 - q) / Math.log(1 - v)); // failed samples tolerated before shrinking
		System.out.println("sampleCount1\t" + sampleCount1);
		System.out.println("sampleCount2\t" + sampleCount2);

		try {
			// 3 Take sampleCount1 random samples x_i from parameter space D
			ArrayList<SparkConfiguration> paraList = new ArrayList<SparkConfiguration>();
			ArrayList<Integer> timeList = new ArrayList<Integer>();
			int totalCount = sampleCount1; // the initial samples count towards the global budget

			SparkConfiguration defaultConfiguration = new SparkConfiguration();
			long defaultTime = f(defaultConfiguration, groupNameList);

			ParameterDistrict S = new ParameterDistrict(groupNameList);
			for (int i = 0; i < sampleCount1; i++) {
				SparkConfiguration tmPara = S.randomSparkConfiguration(groupNameList);
				paraList.add(tmPara);
				timeList.add(f(tmPara, groupNameList));
			}

			// 4 x0 = argmin f(x_i), staticBaseLine = f(x0), add f(x0) to the threshold set F
			int min = Integer.MAX_VALUE;
			for (int t : timeList) {
				if (t < min) min = t;
			}
			SparkConfiguration x0 = paraList.get(timeList.indexOf(min));
			int staticBaseLine = min;
			F.add(min);

			// 5 i = 0, exploitFlag = true, Xopt = x0
			int i = 0;
			boolean exploitFlag = true;
			SparkConfiguration Xopt = x0;

			long startime = System.currentTimeMillis();
			// 6 Stop when the search exceeds MAX_SEARCH_TIME_MS or RANDOM_SEARCH_COUNT samples.
			while (totalCount < Global.RANDOM_SEARCH_COUNT && System.currentTimeMillis() - startime <= Global.MAX_SEARCH_TIME_MS) {
				if (totalCount % 500 == 0) {
					// Progress = complement of (budget remaining * time remaining), as a percentage.
					double searchCountRate = (totalCount / (double) Global.RANDOM_SEARCH_COUNT);
					double timeRate = (System.currentTimeMillis() - startime) / (double) Global.MAX_SEARCH_TIME_MS;
					int completeRatio = (int) ((1 - (1 - searchCountRate) * (1 - timeRate)) * 100);
					UserSubmitInterface_test.UIOutPut("正在搜索优化参数..........." + completeRatio + "%");
				}

				// 7-17 Exploitation: local random search around the promising point x0.
				if (exploitFlag) {
					// 8 j = 0, fc = f(x0), xLocalCenter = x0
					int j = 0;
					int fc = f(x0, groupNameList);
					SparkConfiguration xLocalCenter = x0;
					// Fresh district per exploitation round; it is shrunk as the search narrows.
					ParameterDistrict district = new ParameterDistrict(groupNameList);

					// tmpdistrict is the actual sampling neighbourhood; it is only rebuilt when
					// district or xLocalCenter changes, to save work.
					ParameterDistrict tmpdistrict = new ParameterDistrict(groupNameList);
					// shrink() refuses (returns false) once the region is already small enough.
					boolean smallEnough = !tmpdistrict.shrink(xLocalCenter, shrinkRatio1, groupNameList);

					while (!smallEnough) { // keep sampling until the region is small enough
						// 10 Take a random sample x' from the neighbourhood of xLocalCenter
						SparkConfiguration x_ = tmpdistrict.randomSparkConfiguration(groupNameList);
						// 11 if f(x') < fc: found a better point, re-centre the sample space on it
						if (f(x_, groupNameList) < fc) {
							// 12-13 xLocalCenter = x', fc = f(x'), j = 0
							xLocalCenter = x_;
							fc = f(x_, groupNameList);
							j = 0;
							tmpdistrict = (ParameterDistrict) district.clone();
							smallEnough = !tmpdistrict.shrink(xLocalCenter, 0.1, groupNameList);
							if (smallEnough) break;
						} else {
							// 14 j = j + 1
							j++;
						}
						// 15-16 After sampleCount2 failed samples a better point is
						// (probabilistically) unlikely here: shrink the region and go deeper.
						if (j >= sampleCount2) {
							// shrink() returns false once the region is minimal, ending this round:
							// the space only holds a limited number of points and enough were drawn.
							smallEnough = !district.shrink(xLocalCenter, shrinkRatio2, groupNameList);
							if (smallEnough) break;
							tmpdistrict = (ParameterDistrict) district.clone();
							tmpdistrict.shrink(xLocalCenter, c, groupNameList);
							j = 0;
						}
					}

					// 17 Exploitation done: a good-enough local point was found in this subspace;
					// update the global optimum if it is better.
					exploitFlag = false;
					if (f(xLocalCenter, groupNameList) < f(Xopt, groupNameList)) Xopt = xLocalCenter;
					exploitCount++;
				}

				// 18 Next candidate x0: prior knowledge (seed list) is consumed before falling
				// back to a random sample from S.
				if (listlist.size() > 0) {
					x0 = listlist.get(0);
					listlist.remove(0);
					if (Global.DEBUG_FLAG) System.out.println(x0.toString());
				} else {
					x0 = S.randomSparkConfiguration(groupNameList);
				}
				int f_x0 = f(x0, groupNameList);
				timeList.add(f_x0);

				// 19-20 If f(x0) beats the statistical baseline, the point is promising: exploit it.
				if (f_x0 < staticBaseLine) {
					exploitFlag = true;
				}

				// 21-23 Every sampleCount1 samples, refresh the exploitation threshold:
				// add the window minimum to F and set staticBaseLine = mean(F).
				if (i >= sampleCount1) {
					min = Integer.MAX_VALUE;
					for (int k : timeList) {
						if (k < min) min = k;
					}
					F.add(min);
					int sum = 0;
					for (int count : F) {
						sum += count;
					}
					staticBaseLine = sum / F.size();
					i = 0;
					timeList.clear();
				}
				// 24 i = i + 1
				i++;
				totalCount++;
			}

			if (Global.DEBUG_FLAG) {
				for (String t : Xopt.parameterMap.keySet()) {
					SingleParameter tmp = Xopt.parameterMap.get(t);
					System.out.println(tmp.name + tmp.value);
				}
				System.out.println("找到最优参数：\t" + Xopt.toString());
				// Cast to double: the previous int/long division truncated the ratio to 0 or 1
				// (and mislabelled the ratio as "ms").
				System.out.println("最优参数预计优化比：\t" + f(Xopt, groupNameList) / (double) defaultTime);
			}
			return Xopt;
		} finally {
			// Always shut the prediction daemon down, even if the search aborts with an
			// exception (previously the daemon process leaked on the exception path).
			if (this.h != null) this.end();
		}
	}

	/**
	 * Blocks until the daemon process exits and returns its exit code.
	 */
	public int getThreadReturnValue() throws InterruptedException {
		return this.h.daemonProcess.waitFor();
	}

	/**
	 * Predicted cost (runtime) of a configuration, obtained from the Python daemon.
	 * Also records the prediction on the configuration itself (tmPara.f_x).
	 *
	 * @param tmPara        configuration to evaluate
	 * @param groupNameList parameter group names forwarded to the daemon
	 * @return the predicted value; -1 signals a predictor fault (meaningless after
	 *         normalization, deliberately passed through instead of aborting)
	 * @throws IllegalStateException if init() has not started the daemon — the previous
	 *         revision returned null here, which caused an NPE on unboxing at every call site
	 */
	private Integer f(SparkConfiguration tmPara, ArrayList<String> groupNameList) {
		callCount++;
		if (this.h == null) {
			throw new IllegalStateException("prediction daemon not started; call init() first");
		}
		int f_x = -1;
		try {
			f_x = this.h.pythonDaemonPredict(tmPara, groupNameList);
			tmPara.f_x = f_x;
		} catch (Exception e) {
			// Any daemon/IPC failure is fatal for the search process.
			e.printStackTrace();
			System.exit(1);
		}
		return f_x;
	}

	/**
	 * Placeholder for a gradient-descent-like search strategy; not implemented.
	 *
	 * @return always null
	 */
	public SparkConfiguration search2() {
		return null;
	}

}
